{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 达观杯2021"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E:\\兼职\\深度之眼\\比赛训练营\\21年8月-达观-风险事件标签识别\\ppts\\第一课-开营\\codes\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "%cd ../../"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载数据集，并切分train/dev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 加载数据\n",
    "df_train = pd.read_csv(\"./datasets/phase_1/splits/fold_0/train.txt\")\n",
    "df_train.columns = [\"id\", \"text\", \"label\"]\n",
    "df_val = pd.read_csv(\"./datasets/phase_1/splits/fold_0/dev.txt\")\n",
    "df_val.columns = [\"id\", \"text\", \"label\"]\n",
    "df_test = pd.read_csv(\"./datasets/phase_1/splits/fold_0/test.txt\")\n",
    "df_test.columns = [\"id\", \"text\", ]\n",
    "\n",
    "# 构建词表\n",
    "charset = set()\n",
    "for text in df_train['text']:\n",
    "    for char in text.split(\" \"):\n",
    "        charset.add(char)\n",
    "id2char = ['OOV', '，', '。', '！', '？'] + list(charset)\n",
    "char2id = {id2char[i]: i for i in range(len(id2char))}\n",
    "\n",
    "# 标签集\n",
    "id2label = list(df_train['label'].unique())\n",
    "label2id = {id2label[i]: i for i in range(len(id2label))}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         [(None, 128)]             0         \n",
      "_________________________________________________________________\n",
      "embedding (Embedding)        (None, 128, 128)          402176    \n",
      "_________________________________________________________________\n",
      "bidirectional (Bidirectional (None, 128, 256)          263168    \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 32768)             0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 35)                1146915   \n",
      "=================================================================\n",
      "Total params: 1,812,259\n",
      "Trainable params: 1,812,259\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# 定义模型\n",
    "\n",
    "from tensorflow.keras.layers import *\n",
    "from tensorflow.keras.models import *\n",
    "\n",
    "MAX_LEN = 128\n",
    "input_layer = Input(shape=(MAX_LEN,))\n",
    "layer = Embedding(input_dim=len(id2char), output_dim=128)(input_layer)\n",
    "layer = Bidirectional(LSTM(128, return_sequences=True))(layer)\n",
    "layer = Flatten()(layer)  # [*, 128, 256] --> [*, 128 * 256]\n",
    "output_layer = Dense(len(id2label), activation='softmax')(layer)\n",
    "model = Model(inputs=input_layer, outputs=output_layer)\n",
    "model.summary()\n",
    "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 准备输入数据\n",
    "\n",
    "对训练集、验证集、测试集进行输入转换，构造模型输入。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "import numpy as np\n",
    "\n",
    "X_train, X_val, X_test = [], [], []\n",
    "y_train = np.zeros((len(df_train), len(id2label)), dtype=np.int8)\n",
    "y_val = np.zeros((len(df_val), len(id2label)), dtype=np.int8)\n",
    "\n",
    "for i in range(len(df_train)):\n",
    "    X_train.append([char2id[char] for char in df_train.loc[i, 'text'].split(\" \")])\n",
    "    y_train[i][label2id[df_train.loc[i, 'label']]] = 1\n",
    "for i in range(len(df_val)):\n",
    "    X_val.append([char2id[char] if char in char2id else 0 for char in df_val.loc[i, 'text'].split(\" \")])\n",
    "    y_val[i][label2id[df_val.loc[i, 'label']]] = 1\n",
    "for i in range(len(df_test)):\n",
    "    X_test.append([char2id[char] if char in char2id else 0 for char in df_test.loc[i, 'text'].split(\" \")])\n",
    "\n",
    "X_train = pad_sequences(X_train, maxlen=MAX_LEN, padding='post', truncating='post')\n",
    "X_val = pad_sequences(X_val, maxlen=MAX_LEN, padding='post', truncating='post')\n",
    "X_test = pad_sequences(X_test, maxlen=MAX_LEN, padding='post', truncating='post')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2797/2797 [==============================] - 52s 16ms/step - loss: 1.9545 - accuracy: 0.4706 - val_loss: 1.5315 - val_accuracy: 0.5727\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x16923977a90>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x=X_train, y=y_train, validation_data=(X_val, y_val), epochs=1, batch_size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 2 0 0]\n",
      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "{\n",
      "  \"0__precision\": 0.8443579766536965,\n",
      "  \"0__recall\": 0.6977491961414791,\n",
      "  \"0__f1-score\": 0.7640845070422536,\n",
      "  \"0__support\": 311,\n",
      "  \"1__precision\": 0.547945205479452,\n",
      "  \"1__recall\": 0.4332129963898917,\n",
      "  \"1__f1-score\": 0.4838709677419355,\n",
      "  \"1__support\": 277,\n",
      "  \"2__precision\": 0.599483204134367,\n",
      "  \"2__recall\": 0.3642072213500785,\n",
      "  \"2__f1-score\": 0.453125,\n",
      "  \"2__support\": 637,\n",
      "  \"3__precision\": 0.6,\n",
      "  \"3__recall\": 0.6190476190476191,\n",
      "  \"3__f1-score\": 0.609375,\n",
      "  \"3__support\": 63,\n",
      "  \"4__precision\": 0.6,\n",
      "  \"4__recall\": 0.6097560975609756,\n",
      "  \"4__f1-score\": 0.6048387096774194,\n",
      "  \"4__support\": 123,\n",
      "  \"5__precision\": 0.35714285714285715,\n",
      "  \"5__recall\": 0.39473684210526316,\n",
      "  \"5__f1-score\": 0.37500000000000006,\n",
      "  \"5__support\": 38,\n",
      "  \"6__precision\": 0.23684210526315788,\n",
      "  \"6__recall\": 0.4090909090909091,\n",
      "  \"6__f1-score\": 0.3,\n",
      "  \"6__support\": 22,\n",
      "  \"7__precision\": 0.75,\n",
      "  \"7__recall\": 0.8541666666666666,\n",
      "  \"7__f1-score\": 0.7987012987012988,\n",
      "  \"7__support\": 144,\n",
      "  \"8__precision\": 0.4,\n",
      "  \"8__recall\": 0.3673469387755102,\n",
      "  \"8__f1-score\": 0.3829787234042553,\n",
      "  \"8__support\": 49,\n",
      "  \"9__precision\": 0.3611111111111111,\n",
      "  \"9__recall\": 0.5652173913043478,\n",
      "  \"9__f1-score\": 0.44067796610169496,\n",
      "  \"9__support\": 23,\n",
      "  \"10__precision\": 0.6218487394957983,\n",
      "  \"10__recall\": 0.6271186440677966,\n",
      "  \"10__f1-score\": 0.6244725738396625,\n",
      "  \"10__support\": 118,\n",
      "  \"11__precision\": 0.8888888888888888,\n",
      "  \"11__recall\": 0.8502415458937198,\n",
      "  \"11__f1-score\": 0.8691358024691357,\n",
      "  \"11__support\": 207,\n",
      "  \"12__precision\": 0.568,\n",
      "  \"12__recall\": 0.5182481751824818,\n",
      "  \"12__f1-score\": 0.5419847328244274,\n",
      "  \"12__support\": 137,\n",
      "  \"13__precision\": 0.1111111111111111,\n",
      "  \"13__recall\": 0.5,\n",
      "  \"13__f1-score\": 0.1818181818181818,\n",
      "  \"13__support\": 6,\n",
      "  \"14__precision\": 0.05555555555555555,\n",
      "  \"14__recall\": 0.2222222222222222,\n",
      "  \"14__f1-score\": 0.08888888888888888,\n",
      "  \"14__support\": 18,\n",
      "  \"15__precision\": 0.6081081081081081,\n",
      "  \"15__recall\": 0.6474820143884892,\n",
      "  \"15__f1-score\": 0.627177700348432,\n",
      "  \"15__support\": 139,\n",
      "  \"16__precision\": 0.4765625,\n",
      "  \"16__recall\": 0.5648148148148148,\n",
      "  \"16__f1-score\": 0.5169491525423728,\n",
      "  \"16__support\": 108,\n",
      "  \"17__precision\": 0.76875,\n",
      "  \"17__recall\": 0.7409638554216867,\n",
      "  \"17__f1-score\": 0.7546012269938649,\n",
      "  \"17__support\": 166,\n",
      "  \"18__precision\": 0.5625,\n",
      "  \"18__recall\": 0.7297297297297297,\n",
      "  \"18__f1-score\": 0.6352941176470588,\n",
      "  \"18__support\": 37,\n",
      "  \"19__precision\": 0.4745762711864407,\n",
      "  \"19__recall\": 0.5833333333333334,\n",
      "  \"19__f1-score\": 0.5233644859813085,\n",
      "  \"19__support\": 48,\n",
      "  \"20__precision\": 0.8026315789473685,\n",
      "  \"20__recall\": 0.8472222222222222,\n",
      "  \"20__f1-score\": 0.8243243243243243,\n",
      "  \"20__support\": 72,\n",
      "  \"21__precision\": 0.1875,\n",
      "  \"21__recall\": 0.2857142857142857,\n",
      "  \"21__f1-score\": 0.22641509433962265,\n",
      "  \"21__support\": 21,\n",
      "  \"22__precision\": 0.1,\n",
      "  \"22__recall\": 0.5,\n",
      "  \"22__f1-score\": 0.16666666666666669,\n",
      "  \"22__support\": 10,\n",
      "  \"23__precision\": 0.12280701754385964,\n",
      "  \"23__recall\": 0.3888888888888889,\n",
      "  \"23__f1-score\": 0.18666666666666665,\n",
      "  \"23__support\": 18,\n",
      "  \"24__precision\": 0.1875,\n",
      "  \"24__recall\": 1.0,\n",
      "  \"24__f1-score\": 0.3157894736842105,\n",
      "  \"24__support\": 6,\n",
      "  \"25__precision\": 0.0,\n",
      "  \"25__recall\": 0.0,\n",
      "  \"25__f1-score\": 0.0,\n",
      "  \"25__support\": 0,\n",
      "  \"26__precision\": 0.027777777777777776,\n",
      "  \"26__recall\": 0.25,\n",
      "  \"26__f1-score\": 0.049999999999999996,\n",
      "  \"26__support\": 4,\n",
      "  \"27__precision\": 0.0,\n",
      "  \"27__recall\": 0.0,\n",
      "  \"27__f1-score\": 0.0,\n",
      "  \"27__support\": 0,\n",
      "  \"28__precision\": 0.0,\n",
      "  \"28__recall\": 0.0,\n",
      "  \"28__f1-score\": 0.0,\n",
      "  \"28__support\": 0,\n",
      "  \"29__precision\": 0.5238095238095238,\n",
      "  \"29__recall\": 0.6470588235294118,\n",
      "  \"29__f1-score\": 0.5789473684210527,\n",
      "  \"29__support\": 17,\n",
      "  \"30__precision\": 0.0,\n",
      "  \"30__recall\": 0.0,\n",
      "  \"30__f1-score\": 0.0,\n",
      "  \"30__support\": 0,\n",
      "  \"31__precision\": 0.0,\n",
      "  \"31__recall\": 0.0,\n",
      "  \"31__f1-score\": 0.0,\n",
      "  \"31__support\": 0,\n",
      "  \"32__precision\": 0.0,\n",
      "  \"32__recall\": 0.0,\n",
      "  \"32__f1-score\": 0.0,\n",
      "  \"32__support\": 1,\n",
      "  \"33__precision\": 0.0,\n",
      "  \"33__recall\": 0.0,\n",
      "  \"33__f1-score\": 0.0,\n",
      "  \"33__support\": 0,\n",
      "  \"34__precision\": 0.0,\n",
      "  \"34__recall\": 0.0,\n",
      "  \"34__f1-score\": 0.0,\n",
      "  \"34__support\": 0,\n",
      "  \"accuracy\": 0.5726950354609929,\n",
      "  \"macro avg__precision\": 0.3538517009202592,\n",
      "  \"macro avg__recall\": 0.4347877266811949,\n",
      "  \"macro avg__f1-score\": 0.36928996086070665,\n",
      "  \"macro avg__support\": 2820,\n",
      "  \"weighted avg__precision\": 0.6317993238015933,\n",
      "  \"weighted avg__recall\": 0.5726950354609929,\n",
      "  \"weighted avg__f1-score\": 0.5909092724558893,\n",
      "  \"weighted avg__support\": 2820\n",
      "}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Users\\Administrator\\anaconda3\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1245: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "D:\\Users\\Administrator\\anaconda3\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1245: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "D:\\Users\\Administrator\\anaconda3\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1245: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "y_val_pred = model.predict(X_val).argmax(axis=-1)\n",
    "print(y_val_pred[: 20])\n",
    "y_val = []\n",
    "for i in range(len(df_val)):\n",
    "    y_val.append(label2id[df_val.loc[i, 'label']])\n",
    "y_val = [int(w) for w in y_val]\n",
    "print(y_val[: 20])\n",
    "\n",
    "from sklearn.metrics import classification_report\n",
    "results = {}\n",
    "classification_report_dict = classification_report(y_val_pred, y_val, output_dict=True)\n",
    "for key0, val0 in classification_report_dict.items():\n",
    "    if isinstance(val0, dict):\n",
    "        for key1, val1 in val0.items():\n",
    "            results[key0 + \"__\" + key1] = val1\n",
    "\n",
    "    else:\n",
    "        results[key0] = val0\n",
    "\n",
    "import json\n",
    "print(json.dumps(results, indent=2, ensure_ascii=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 输出预测结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = model.predict(X_test).argmax(axis=-1)\n",
    "pred_labels = [id2label[i] for i in y_pred]\n",
    "pd.DataFrame({\"id\": df_test['id'], \"label\": pred_labels}).to_csv(\"submission.csv\", index=False)\n",
    "\n",
    "\n",
    "\n",
    "# 提交结果：\n",
    "# 0.36730954652"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
